# plotting HM orthologs for both cases separately
# case 1:  fly ortholog does not exist
# case 2: fly ortholog exists, but is not expressed

import numpy as np
import scipy.sparse as sps
import pickle
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
plt.switch_backend('Agg')
from helperGraph import *
import igraph as ig
from scipy.stats import hypergeom
import os
import graph_tool.all as gt
import itertools

# processing networks
homedir = '/home/andrew'

# HGNC <-> Entrez lookup tables.
# Both dictionaries are built from the same two columns of the HGNC export,
# one in each direction.
hgncTab = pd.read_csv(homedir + '/paper1/hgncEntrezTable.txt', sep = "\t")
_hgncEntrezIds = hgncTab['Entrez Gene ID(supplied by NCBI)']
_hgncSymbols = hgncTab['Approved Symbol']
entrezToHGNC = dict(zip(_hgncEntrezIds, _hgncSymbols))
HGNCToEntrez = dict(zip(_hgncSymbols, _hgncEntrezIds))


def entrezConv(gene):
    """Return the HGNC symbol stored for Entrez ID *gene*.

    Looks the ID up in the module-level ``entrezToHGNC`` dictionary and
    returns the string ``'NA'`` when the ID has no entry.
    """
    return entrezToHGNC.get(gene, 'NA')

# STRING
stringFile = homedir + '/paper1/networks/9606.protein.links.detailed.v10.5.txt'
string_map_file = homedir + '/paper1/networks/entrez_gene_id.vs.string.v10.28042015.tsv'

stringTab = pd.read_csv(stringFile, sep = ' ')
# keep only links whose combined score is above 700
stringFilt = stringTab[stringTab['combined_score'] > 700]

# STRING locus ID -> Entrez gene ID
stringMap = pd.read_csv(string_map_file, sep ='\t')
stringDict = dict(zip(stringMap['STRING_Locus_ID'],stringMap['#Entrez_Gene_ID']))

stringGraph = nx.Graph()

# Iterate the two ID columns directly instead of DataFrame.iterrows():
# iterrows() builds a Series per row, which is very slow on a table this size.
# An edge is added only when both endpoints have an Entrez mapping.
for protein1, protein2 in zip(stringFilt['protein1'], stringFilt['protein2']):
    if protein1 in stringDict and protein2 in stringDict:
        stringGraph.add_edge(stringDict[protein1], stringDict[protein2], weight = 1)

# reading in BioPlex 2.0
bioplexFile = homedir + '/paper1/networks/BioPlex_interactionList_v4a.tsv'
bioplexTab = pd.read_csv(bioplexFile, sep = '\t')

bioplexGraph = nx.Graph()

# One edge per table row, taken straight from the two gene-ID columns.
# This avoids the per-row Series that DataFrame.iterrows() would build.
bioplexGraph.add_edges_from(zip(bioplexTab['GeneA'], bioplexTab['GeneB']))

# reading inWeb
inwebFile = homedir + '/paper1/networks/inweb_hgnc.txt'
inwebTab = pd.read_csv(inwebFile,sep = '\t', header = None)
inwebTab.columns = ['GeneA','GeneB','Prob']
inwebGraph = nx.Graph()

# The table is keyed by HGNC symbol; edges are stored by Entrez ID, so a pair
# is kept only when both symbols can be converted.
for geneA, geneB in zip(inwebTab['GeneA'], inwebTab['GeneB']):
    if geneA in HGNCToEntrez and geneB in HGNCToEntrez:
        inwebGraph.add_edge(HGNCToEntrez[geneA], HGNCToEntrez[geneB])



# reading humap
humapFile = homedir+ '/paper1/networks/humap_pairsWprob.txt'
humapTab = pd.read_csv(humapFile, sep = '\t', header = None)
humapTab.columns = ['GeneA','GeneB','Prob']

# keep only pairs with probability above 0.5
humapFilt = humapTab[humapTab['Prob'] > 0.5]

humapGraph = nx.Graph()

# Read the ID columns directly rather than through DataFrame.iterrows():
# iterrows() returns each row as a single-dtype Series, so numeric gene IDs
# sitting next to the float 'Prob' column would be upcast to floats.
humapGraph.add_edges_from(zip(humapFilt['GeneA'], humapFilt['GeneB']))




# interactor File
# Each file holds one gene symbol per line. The files are opened with `with`
# so the handles are closed as soon as the lines have been read.
with open(homedir + '/paper1/networks/interactors_curated.txt','r') as curatedFile:
    ints_cur_S = [line.rstrip('\n') for line in curatedFile]
with open(homedir + '/paper1/networks/interactors_thisstudy.txt','r') as studyFile:
    ints_study_S = [line.rstrip('\n') for line in studyFile]

# Convert symbols to integer Entrez IDs; symbols without a mapping are dropped.
ints_cur = [int(HGNCToEntrez[gene]) for gene in ints_cur_S if gene in HGNCToEntrez]
ints_study = [int(HGNCToEntrez[gene]) for gene in ints_study_S if gene in HGNCToEntrez]
# union of both interactor lists, without duplicates
ints_both = list(set(ints_cur + ints_study))


# loading in modifiers
modifierFile = '/home/andrew/paper1/networks/tested_genes_modifiers_report_CHDI_2016.xls'
# NOTE(review): `sheetname` is the older pandas spelling of this argument;
# newer pandas releases expect `sheet_name` -- confirm against the installed version.
modTab1 = pd.read_excel(modifierFile,sheetname = 0)
modTab2 = pd.read_excel(modifierFile,sheetname = 1)

# A row counts as a modifier when either effect column holds 'E' or 'S'
# (presumably enhancer / suppressor -- confirm against the report legend).
effectCodes = ['E', 'S']

# Sheet 1: effect columns are addressed by name, genes are human symbols.
modMask1 = (modTab1[' Effect on HTT-NT231Q128Q Drosophila'].isin(effectCodes)
            | modTab1[' Effect on HTT-FLQ200 Drosophila'].isin(effectCodes))
mods1 = list(modTab1[modMask1]['Human Gene'])
mods1_entrez = {HGNCToEntrez[gene] for gene in mods1 if gene in HGNCToEntrez}

# Sheet 2: effect columns are addressed by position (7th and 8th column).
modMask2 = modTab2.iloc[:,6].isin(effectCodes) | modTab2.iloc[:,7].isin(effectCodes)
# all-uppercase names are treated as human symbols
modsHuman = modTab2['Mouse / human  gene'].apply(lambda x: x.isupper())
mods2p = set(modTab2[modMask2]['Mouse / human  gene'])
mods2p_h = list(modTab2[modMask2 & modsHuman]['Mouse / human  gene'])

mods2p_hc = [HGNCToEntrez[gene] for gene in mods2p_h if gene in HGNCToEntrez]

# convert mods2p into all human
mouse2Human = pd.read_csv("/home/andrew/extData/hdinhd/orthoConv2.txt",sep = "\t")
# .copy() gives an independent frame, so adding the 'human' column below does
# not assign into a possible view of mouse2Human (SettingWithCopyWarning).
mouse2Human2 = mouse2Human.dropna(subset = ['Gene ID.1']).copy()
# when several IDs are listed (';'-separated) only the first one is kept
mouse2Human2['human'] = mouse2Human2['Gene ID.1'].apply(lambda x: int(x.split(';')[0]))
convDictMH = dict(zip(mouse2Human2['Gene ID'],mouse2Human2['human']))

# NOTE(review): every ID in MouseEntrez.txt is converted here; `mods2p` itself
# is not consulted -- presumably the file already lists the mouse modifiers.
mouseEntrezTab = pd.read_csv('/home/andrew/paper1/networks/MouseEntrez.txt', sep = '\t')

convertedMouseIDs = [convDictMH[gene] for gene in mouseEntrezTab['Entrez Gene ID'] if gene in convDictMH]

modsHumanHGNC = set(mods2p_hc + convertedMouseIDs)

# all modifiers from both sheets, as one set
combMods = mods1_entrez.union(modsHumanHGNC)

# case 1
# list input dir
geneListDir = '/home/andrew/paper1/NonOrthologsCase1_092617/'

# NOTE: unpickling can execute arbitrary code; only load result files that
# were produced by this project.
# `with` closes the file handle once the dictionary is loaded.
with open(geneListDir + 'resultDict.p','rb') as resultFile:
    resultDict = pickle.load(resultFile)
upKey = ('up', 'allCases', 'Neuron', 'case6')
downKey = ('down', 'allCases', 'Neuron', 'case6')


def getGeneLists(geneTab):
    """Flatten the per-row gene lists of *geneTab* into one list per species.

    Only rows whose 'Count' is greater than zero contribute. Each cell of the
    'Fly', 'Mouse' and 'Human' columns is an iterable of genes; the cells of
    a column are concatenated in row order.

    Returns a tuple ``(fly, mouse, human)`` of lists.
    """
    counted = geneTab.loc[geneTab['Count'] > 0]
    flyGenes, mouseGenes, humanGenes = (
        list(itertools.chain.from_iterable(counted[species]))
        for species in ('Fly', 'Mouse', 'Human')
    )
    return (flyGenes, mouseGenes, humanGenes)

# Human gene list per result key, indexed by the first element of the key.
# NOTE(review): keys that share the same first element overwrite each other
# here (last one wins) -- confirm that is intended.
humanDegs = {}
for key, result in resultDict.items():
    fly_degs_o, mouse_degs_o, human_degs_o = getGeneLists(result[1])
    humanDegs[key[0]] = human_degs_o


# network to analyse: a private copy of the STRING graph
myNet = stringGraph.copy()

# output directory, created on first use
outDir = homedir + '/paper1/networks/STRING_NonOrthoCase1only_092617/'
if not os.path.exists(outDir):
    os.makedirs(outDir)

# Per-list summary table: one row per gene list, filled in by the loop below.
# (This definition used to be commented out, which left `intTab` undefined.)
intTab = pd.DataFrame(columns = ['NumberInNetwork','NumWStudy','PValInterWStudy',
                                 'NumWCur', 'PValInterWCur', 'NumWBoth', 'PValInterWBoth'],
                      index = list(humanDegs.keys()))

# gene list name -> set of network genes that are also interactors
interWStudy = {}
interWCur = {}
interWBoth = {}

# Loop-invariant lookups, built once.
studySet = set(ints_study)
curSet = set(ints_cur)
bothSet = set(ints_both)
netSize = len(myNet.nodes())

# For every gene list: cut the sub-network out of myNet, cluster it with
# infomap, annotate the nodes, test for interactor enrichment, then write the
# graphs, a PDF drawing and a per-node attribute table to outDir.
for name, myList in humanDegs.items():
    plt.close('all')

    # 'NA' entries are dropped before taking the subgraph
    myList = set(myList).difference(['NA'])
    subgraph = nx.subgraph(myNet, myList)

    # one fixed node order, reused for the adjacency matrix and all attributes
    nodes = list(subgraph.nodes())
    n = len(nodes)
    integerNodes = [int(node) for node in nodes]

    # igraph copy of the subgraph (same node order) for community detection
    g2 = ig.Graph.Adjacency((nx.to_numpy_matrix(subgraph, nodelist = nodes) > 0).tolist())
    g2.vs['Entrez'] = nodes
    # entrezConv() returns 'NA' for IDs missing from the HGNC table, where a
    # direct entrezToHGNC[...] lookup would raise KeyError
    g2.vs['Symbol'] = [entrezConv(gene) for gene in nodes]

    test = g2.community_infomap()

    # connected components, largest first
    conComps = sorted(nx.connected_components(subgraph), key = len, reverse=True)

    #attributes
    clustdict = dict(zip(nodes, test.membership))
    intStudyDict = {node : node in studySet for node in nodes}
    intCurDict = {node : node in curSet for node in nodes}
    hgncDict = {node : entrezConv(node) for node in nodes}
    modDict = {node : node in combMods for node in nodes}

    # (graph, name, values) is the networkx 1.x argument order used by this file
    nx.set_node_attributes(subgraph, 'clust', clustdict)
    nx.set_node_attributes(subgraph, 'intStudy', intStudyDict)
    nx.set_node_attributes(subgraph, 'intCur', intCurDict)
    nx.set_node_attributes(subgraph, 'hgnc', hgncDict)
    nx.set_node_attributes(subgraph, 'mod', modDict)

    # Within-community edges get a higher weight than between-community edges.
    # The comparison used to be clustdict[v1] == clustdict[v1], which is always
    # true. Both branches are written explicitly because edge data may be
    # shared with myNet (networkx 1.x subgraph semantics), so a weight set for
    # an earlier gene list could otherwise carry over.
    for v1, v2 in subgraph.edges():
        subgraph[v1][v2]['weight'] = 3 if clustdict[v1] == clustdict[v2] else 1

    #calculating various attributes
    nodeIdSet = set(integerNodes)
    studyHits = nodeIdSet & studySet
    curHits = nodeIdSet & curSet
    bothHits = nodeIdSet & bothSet

    intTab.loc[name,'NumberInNetwork'] = n
    intTab.loc[name,'NumWStudy'] = len(studyHits)
    intTab.loc[name,'NumWCur'] = len(curHits)
    intTab.loc[name,'NumWBoth'] = len(bothHits)
    interWStudy[name] = studyHits
    interWCur[name] = curHits
    interWBoth[name] = bothHits

    # 1. Hypergeometric on interactors
    # Enrichment p-value P(X >= k) = sf(k - 1). pmf(k) is only the probability
    # of observing exactly k overlaps and is not a p-value.
    # Arguments: population = network size, successes = interactor count,
    # draws = number of list genes found in the network.
    intTab.loc[name,'PValInterWStudy'] = hypergeom.sf(len(studyHits) - 1, netSize,
                                                        len(ints_study), n)
    intTab.loc[name,'PValInterWCur'] = hypergeom.sf(len(curHits) - 1, netSize,
                                                        len(ints_cur), n)
    intTab.loc[name,'PValInterWBoth'] = hypergeom.sf(len(bothHits) - 1, netSize,
                                                        len(ints_both), n)

    # only components with more than one node are drawn
    nodesToPlot = list(itertools.chain.from_iterable(
        comp for comp in conComps if len(comp) > 1))

    subgraph2 = nx.subgraph(subgraph,nodesToPlot)

    #saving subgraph and subgraph2 as graphml
    nx.write_graphml(subgraph, outDir + name +  "_subgraph.graphml")
    nx.write_graphml(subgraph2, outDir + name +  "_subgraph2.graphml")

    plotNodes = list(subgraph2.nodes())
    colors = [clustdict[node] for node in plotNodes]
    nodelabels = {gene : entrezConv(gene) for gene in plotNodes}

    pos = nx.nx_agraph.graphviz_layout(subgraph2,prog='neato',args ='-Elen=2')

    # labels are drawn y_off layout units above their node
    y_off = 75  # offset on the y axis
    pos_higher = {node : (xy[0], xy[1] + y_off) for node, xy in pos.items()}

    # converting attributes to size: study interactors are drawn larger
    nodeSizes = [15 + 50*int(node in studySet) for node in plotNodes]

    plt.figure()
    nx.draw_networkx_nodes(subgraph2, nodelist = plotNodes,
            pos = pos, cmap = plt.get_cmap('prism'), node_color = colors,
            with_labels = True, node_size = nodeSizes)
    nx.draw_networkx_edges(subgraph2,pos,width=1.0,alpha=0.5)
    nx.draw_networkx_labels(subgraph2,pos_higher,labels = nodelabels,font_size=4)

    plt.axis('off')
    plt.savefig(outDir + name + '.pdf')

    #outputting attributes
    with open(outDir + name + '_cluster.txt', 'w') as f:
        f.write('EntrezID\tCluster\tStudyInteractor\tCuratedInteractor\tModifier\n')
        for k in nodes:
            f.write('\t'.join([str(int(k)), str(clustdict[k]), str(intStudyDict[k]),
                               str(intCurDict[k]), str(modDict[k])]) + '\n')



#######
#####






# case 2
# list input dir
geneListDir = '/home/andrew/paper1/NonOrthologs_081417_case2/'

# NOTE: unpickling can execute arbitrary code; only load result files that
# were produced by this project.
# `with` closes the file handle once the dictionary is loaded.
with open(geneListDir + 'resultDict.p','rb') as resultFile:
    resultDict = pickle.load(resultFile)
upKey = ('up', 'allCases', 'Neuron', 'case6')
downKey = ('down', 'allCases', 'Neuron', 'case6')


def getGeneLists(geneTab):
    """Return ``(fly, mouse, human)`` gene lists gathered from *geneTab*.

    Rows with a 'Count' of zero or less are ignored. For the remaining rows
    the iterables stored in the 'Fly', 'Mouse' and 'Human' columns are
    concatenated, column by column, in row order.

    This repeats the definition given earlier in the file for case 1.
    """
    kept = geneTab.loc[geneTab['Count'] > 0]

    def _flatten(column):
        # concatenate every cell of one column into a single list
        return list(itertools.chain.from_iterable(kept[column]))

    return (_flatten('Fly'), _flatten('Mouse'), _flatten('Human'))

# Human gene list per result key, indexed by the first element of the key.
# NOTE(review): keys that share the same first element overwrite each other
# here (last one wins) -- confirm that is intended.
humanDegs = {}
for key, result in resultDict.items():
    fly_degs_o, mouse_degs_o, human_degs_o = getGeneLists(result[1])
    humanDegs[key[0]] = human_degs_o


# network to analyse: a private copy of the STRING graph
myNet = stringGraph.copy()

# Case 2 gets its own output directory. The path used to be the case 1
# directory, so these files overwrote the case 1 results written earlier.
# NOTE(review): directory name chosen to mirror the case 2 input directory;
# rename if another convention is preferred.
outDir = homedir + '/paper1/networks/STRING_NonOrthoCase2only_081417/'
if not os.path.exists(outDir):
    os.makedirs(outDir)

# Per-list summary table: one row per gene list, filled in by the loop below.
# (The call used to be left unclosed because its second line was commented
# out, which made the whole file a SyntaxError.)
intTab = pd.DataFrame(columns = ['NumberInNetwork','NumWStudy','PValInterWStudy',
                                 'NumWCur', 'PValInterWCur', 'NumWBoth', 'PValInterWBoth'],
                      index = list(humanDegs.keys()))

# gene list name -> set of network genes that are also interactors
interWStudy = {}
interWCur = {}
interWBoth = {}

# Loop-invariant lookups, built once.
studySet = set(ints_study)
curSet = set(ints_cur)
bothSet = set(ints_both)
netSize = len(myNet.nodes())

# For every gene list: cut the sub-network out of myNet, cluster it with
# infomap, annotate the nodes, test for interactor enrichment, then write the
# graphs, a PDF drawing and a per-node attribute table to outDir.
for name, myList in humanDegs.items():
    plt.close('all')

    # 'NA' entries are dropped before taking the subgraph
    myList = set(myList).difference(['NA'])
    subgraph = nx.subgraph(myNet, myList)

    # one fixed node order, reused for the adjacency matrix and all attributes
    nodes = list(subgraph.nodes())
    n = len(nodes)
    integerNodes = [int(node) for node in nodes]

    # igraph copy of the subgraph (same node order) for community detection
    g2 = ig.Graph.Adjacency((nx.to_numpy_matrix(subgraph, nodelist = nodes) > 0).tolist())
    g2.vs['Entrez'] = nodes
    # entrezConv() returns 'NA' for IDs missing from the HGNC table, where a
    # direct entrezToHGNC[...] lookup would raise KeyError
    g2.vs['Symbol'] = [entrezConv(gene) for gene in nodes]

    test = g2.community_infomap()

    # connected components, largest first
    conComps = sorted(nx.connected_components(subgraph), key = len, reverse=True)

    #attributes
    clustdict = dict(zip(nodes, test.membership))
    intStudyDict = {node : node in studySet for node in nodes}
    intCurDict = {node : node in curSet for node in nodes}
    hgncDict = {node : entrezConv(node) for node in nodes}
    modDict = {node : node in combMods for node in nodes}

    # (graph, name, values) is the networkx 1.x argument order used by this file
    nx.set_node_attributes(subgraph, 'clust', clustdict)
    nx.set_node_attributes(subgraph, 'intStudy', intStudyDict)
    nx.set_node_attributes(subgraph, 'intCur', intCurDict)
    nx.set_node_attributes(subgraph, 'hgnc', hgncDict)
    nx.set_node_attributes(subgraph, 'mod', modDict)

    # Within-community edges get a higher weight than between-community edges.
    # The comparison used to be clustdict[v1] == clustdict[v1], which is always
    # true. Both branches are written explicitly because edge data may be
    # shared with myNet (networkx 1.x subgraph semantics), so a weight set for
    # an earlier gene list could otherwise carry over.
    for v1, v2 in subgraph.edges():
        subgraph[v1][v2]['weight'] = 3 if clustdict[v1] == clustdict[v2] else 1

    #calculating various attributes
    nodeIdSet = set(integerNodes)
    studyHits = nodeIdSet & studySet
    curHits = nodeIdSet & curSet
    bothHits = nodeIdSet & bothSet

    intTab.loc[name,'NumberInNetwork'] = n
    intTab.loc[name,'NumWStudy'] = len(studyHits)
    intTab.loc[name,'NumWCur'] = len(curHits)
    intTab.loc[name,'NumWBoth'] = len(bothHits)
    interWStudy[name] = studyHits
    interWCur[name] = curHits
    interWBoth[name] = bothHits

    # 1. Hypergeometric on interactors
    # Enrichment p-value P(X >= k) = sf(k - 1). pmf(k) is only the probability
    # of observing exactly k overlaps and is not a p-value.
    # Arguments: population = network size, successes = interactor count,
    # draws = number of list genes found in the network.
    intTab.loc[name,'PValInterWStudy'] = hypergeom.sf(len(studyHits) - 1, netSize,
                                                        len(ints_study), n)
    intTab.loc[name,'PValInterWCur'] = hypergeom.sf(len(curHits) - 1, netSize,
                                                        len(ints_cur), n)
    intTab.loc[name,'PValInterWBoth'] = hypergeom.sf(len(bothHits) - 1, netSize,
                                                        len(ints_both), n)

    # only components with more than one node are drawn
    nodesToPlot = list(itertools.chain.from_iterable(
        comp for comp in conComps if len(comp) > 1))

    subgraph2 = nx.subgraph(subgraph,nodesToPlot)

    #saving subgraph and subgraph2 as graphml
    nx.write_graphml(subgraph, outDir + name +  "_subgraph.graphml")
    nx.write_graphml(subgraph2, outDir + name +  "_subgraph2.graphml")

    plotNodes = list(subgraph2.nodes())
    colors = [clustdict[node] for node in plotNodes]
    nodelabels = {gene : entrezConv(gene) for gene in plotNodes}

    pos = nx.nx_agraph.graphviz_layout(subgraph2,prog='neato',args ='-Elen=2')

    # labels are drawn y_off layout units above their node
    y_off = 75  # offset on the y axis
    pos_higher = {node : (xy[0], xy[1] + y_off) for node, xy in pos.items()}

    # converting attributes to size: study interactors are drawn larger
    nodeSizes = [15 + 50*int(node in studySet) for node in plotNodes]

    plt.figure()
    nx.draw_networkx_nodes(subgraph2, nodelist = plotNodes,
            pos = pos, cmap = plt.get_cmap('prism'), node_color = colors,
            with_labels = True, node_size = nodeSizes)
    nx.draw_networkx_edges(subgraph2,pos,width=1.0,alpha=0.5)
    nx.draw_networkx_labels(subgraph2,pos_higher,labels = nodelabels,font_size=4)

    plt.axis('off')
    plt.savefig(outDir + name + '.pdf')

    #outputting attributes
    with open(outDir + name + '_cluster.txt', 'w') as f:
        f.write('EntrezID\tCluster\tStudyInteractor\tCuratedInteractor\tModifier\n')
        for k in nodes:
            f.write('\t'.join([str(int(k)), str(clustdict[k]), str(intStudyDict[k]),
                               str(intCurDict[k]), str(modDict[k])]) + '\n')